{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据的读取和预处理\n",
    "目标：将源语言与目标语言的句子转换为词索引（index）序列的tensor表示（不是one-hot向量；每个句子末尾附加EOS_token=1），例如：\n",
    "\n",
    "```\n",
    "tensor([[211, 212, 594,   5,   1]])\n",
    "tensor([[130, 125,  58,   4,   1]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import os\n",
    "import random\n",
    "import torch\n",
    "import unicodedata\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Special-token indices; they match the reserved entries of Language.word2index\n",
    "SOS_token = 0                                   # start-of-sentence token index in the vocabulary\n",
    "EOS_token = 1                                   # end-of-sentence token index, appended by tensorFromSentence\n",
    "MAX_LENGTH = 10                                 # filterPair keeps only sentences with fewer than MAX_LENGTH space-separated tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary class: keeps the word <-> index mappings of one language\n",
    "class Language(object):\n",
    "    \"\"\"Vocabulary of a single language, pre-seeded with the SOS and EOS tokens.\"\"\"\n",
    "\n",
    "    def __init__(self, language):\n",
    "        \"\"\"\n",
    "        language: the name of the language\n",
    "        \"\"\"\n",
    "        self.language = language\n",
    "        # word -> index; indices 0 and 1 are reserved for the special tokens\n",
    "        self.word2index = dict(SOS=0, EOS=1)\n",
    "        # word -> number of occurrences (the special tokens are not counted)\n",
    "        self.word2count = {}\n",
    "        # index -> word, special tokens included\n",
    "        self.index2word = {index: word for word, index in self.word2index.items()}\n",
    "        # vocabulary size, special tokens included\n",
    "        self.n_words = 2\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        \"\"\"Add every token of `sentence` (split on single spaces) to the vocabulary.\"\"\"\n",
    "        for token in sentence.split(' '):\n",
    "            self.addWord(token)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        \"\"\"Count one occurrence of `word`; a new word gets the next free index.\"\"\"\n",
    "        if word in self.word2index:\n",
    "            # already known: only the occurrence counter changes\n",
    "            self.word2count[word] += 1\n",
    "            return\n",
    "        new_index = self.n_words\n",
    "        self.word2index[word] = new_index\n",
    "        self.word2count[word] = 1\n",
    "        self.index2word[new_index] = word\n",
    "        self.n_words = new_index + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unicodeToAscii(s):\n",
    "    \"\"\"Remove combining marks: NFD-decompose `s` and drop every character of Unicode category 'Mn'.\"\"\"\n",
    "    decomposed = unicodedata.normalize('NFD', s)\n",
    "    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']\n",
    "    return ''.join(kept)\n",
    "\n",
    "# Lowercase, trim, and remove non-letter characters\n",
    "def normalizeString(s):\n",
    "    \"\"\"Normalize one sentence for vocabulary building.\n",
    "\n",
    "    The sentence is lowercased, trimmed and stripped of combining marks; a space is\n",
    "    inserted before each of . ! ? and every run of characters other than letters\n",
    "    and . ! ? is collapsed into a single space.\n",
    "\n",
    "    s: the raw sentence\n",
    "    Returns the normalized sentence without leading or trailing whitespace.\n",
    "    \"\"\"\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    s = re.sub(r\"([.!?])\", r\" \\1\", s)       # add space between character and interpunction\n",
    "    s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)   # only match the a-z/A-Z/.!?\n",
    "    # The substitution above turns leading/trailing non-letters (digits, quotes, ...)\n",
    "    # into a space; strip it so that split(' ') yields no empty token and prefix\n",
    "    # checks with startswith() see the first word.\n",
    "    return s.strip()\n",
    "\n",
    "\n",
    "def readLangs(lang1, lang2, reverse=False):\n",
    "    \"\"\"Read the parallel corpus file and create two empty vocabularies.\n",
    "\n",
    "    The file '<lang1>-<lang2>.txt' is read from '../../data/translation'; each line\n",
    "    is split on tabs and every field is normalized with normalizeString.\n",
    "\n",
    "    lang1, lang2: language names, in the order used by the file name\n",
    "    reverse: when True, each pair is reversed and lang2 becomes the input language\n",
    "    Returns (input_lang, output_lang, pairs); the Language objects contain only the\n",
    "    special tokens at this point.\n",
    "    \"\"\"\n",
    "    print(\"Reading lines...\")\n",
    "\n",
    "    # Read the file and split into lines; the context manager closes the file handle\n",
    "    file_path = os.path.join('../..', 'data', 'translation', '%s-%s.txt' % (lang1, lang2))\n",
    "    with open(file_path, encoding='utf-8') as corpus_file:\n",
    "        lines = corpus_file.read().strip().split('\\n')\n",
    "\n",
    "    # Split every line into pairs and normalize\n",
    "    pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
    "\n",
    "    # Reverse pairs, make Lang instances\n",
    "    if reverse:\n",
    "        pairs = [list(reversed(p)) for p in pairs]\n",
    "        input_lang = Language(lang2)\n",
    "        output_lang = Language(lang1)\n",
    "    else:\n",
    "        input_lang = Language(lang1)\n",
    "        output_lang = Language(lang2)\n",
    "\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "\n",
    "# English sentence prefixes accepted by filterPair, written in normalized form\n",
    "# (lowercase, apostrophes replaced by spaces). Contracted forms end with a space so\n",
    "# that 'she s ' matches 'she s happy' but not 'she said'.\n",
    "eng_prefixes = (\"i am \", \"i m \", \"he is\", \"he s \", \"she is\", \"she s \",\n",
    "                \"you are\", \"you re \", \"we are\", \"we re \", \"they are\",\n",
    "                \"they re \")\n",
    "\n",
    "\n",
    "def filterPair(p):       # select the short sentence and which is started with the given prefixes\n",
    "    \"\"\"Return True when the sentence pair `p` should be kept.\n",
    "\n",
    "    A pair is kept when both sentences have fewer than MAX_LENGTH space-separated\n",
    "    tokens and the second sentence starts with one of eng_prefixes.\n",
    "    NOTE(review): the prefix test is applied to p[1], so it presumes the English\n",
    "    sentence is the second element (as produced with reverse=True) - confirm\n",
    "    before using it with non-reversed pairs.\n",
    "    \"\"\"\n",
    "    if len(p) < 2:\n",
    "        # Malformed line (no tab separator): there is no second sentence to check\n",
    "        return False\n",
    "    return (len(p[0].split(' ')) < MAX_LENGTH and\n",
    "            len(p[1].split(' ')) < MAX_LENGTH and\n",
    "            p[1].startswith(eng_prefixes))\n",
    "\n",
    "\n",
    "def filterPairs(pairs):  # filter all pairs\n",
    "    \"\"\"Return a new list holding the pairs accepted by filterPair, in their original order.\"\"\"\n",
    "    return list(filter(filterPair, pairs))\n",
    "\n",
    "\n",
    "def prepareData(lang1, lang2, reverse=False):\n",
    "    \"\"\"Load the corpus, filter the sentence pairs and build both vocabularies.\n",
    "\n",
    "    lang1, lang2, reverse: forwarded unchanged to readLangs\n",
    "    Returns (input_lang, output_lang, pairs), where `pairs` holds only the pairs\n",
    "    kept by filterPairs and both Language objects contain the words of those pairs.\n",
    "    Progress information and one randomly chosen pair are printed.\n",
    "    \"\"\"\n",
    "    src_vocab, tgt_vocab, pairs = readLangs(lang1, lang2, reverse)\n",
    "    print(\"Read %s sentence pairs\" % len(pairs))\n",
    "\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
    "\n",
    "    print(\"Counting words...\")\n",
    "    for sentence_pair in pairs:\n",
    "        # element 0 feeds the input vocabulary, element 1 the output vocabulary\n",
    "        src_vocab.addSentence(sentence_pair[0])\n",
    "        tgt_vocab.addSentence(sentence_pair[1])\n",
    "\n",
    "    print(\"Counted words:\")\n",
    "    for vocab in (src_vocab, tgt_vocab):\n",
    "        print(vocab.language, vocab.n_words)\n",
    "    print(random.choice(pairs))\n",
    "    return src_vocab, tgt_vocab, pairs\n",
    "\n",
    "\n",
    "def indexesFromSentence(lang, sentence):              # sentence to indexes\n",
    "    \"\"\"Map each space-separated token of `sentence` to its index in `lang.word2index`.\n",
    "\n",
    "    A token that is missing from the vocabulary raises KeyError.\n",
    "    \"\"\"\n",
    "    vocab = lang.word2index\n",
    "    indexes = []\n",
    "    for token in sentence.split(' '):\n",
    "        indexes.append(vocab[token])\n",
    "    return indexes\n",
    "\n",
    "\n",
    "def tensorFromSentence(lang, sentence):               # sentence to tensor\n",
    "    \"\"\"Return a 1-D LongTensor with the word indexes of `sentence` followed by EOS_token.\"\"\"\n",
    "    token_ids = indexesFromSentence(lang, sentence) + [EOS_token]\n",
    "    return torch.LongTensor(token_ids)\n",
    "\n",
    "\n",
    "def tensorFromPair(input_lang, output_lang, pair):    # pair to tensor\n",
    "    \"\"\"Convert a sentence pair into the tuple (input_tensor, target_tensor).\n",
    "\n",
    "    pair[0] is encoded with `input_lang`, pair[1] with `output_lang`.\n",
    "    \"\"\"\n",
    "    return (tensorFromSentence(input_lang, pair[0]),\n",
    "            tensorFromSentence(output_lang, pair[1]))\n",
    "\n",
    "\n",
    "class TextDataset(Dataset):\n",
    "    \"\"\"Dataset whose items are (input_tensor, target_tensor) word-index sequences.\"\"\"\n",
    "\n",
    "    # The default for `lang` is an immutable tuple, so it cannot be mutated between calls.\n",
    "    def __init__(self, dataload=prepareData, lang=('eng', 'fra')):\n",
    "        \"\"\"\n",
    "        dataload: callable invoked as dataload(lang[0], lang[1], reverse=True) that\n",
    "                  returns (input_lang, output_lang, pairs)\n",
    "        lang: sequence with the names of the two languages, passed to `dataload`\n",
    "              in this order\n",
    "        \"\"\"\n",
    "        self.input_lang, self.output_lang, self.pairs = dataload(\n",
    "            lang[0], lang[1], reverse=True)\n",
    "        # Vocabulary sizes (special tokens included)\n",
    "        self.input_lang_words = self.input_lang.n_words\n",
    "        self.output_lang_words = self.output_lang.n_words\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the pair at `index` as (input_tensor, target_tensor).\"\"\"\n",
    "        return tensorFromPair(self.input_lang, self.output_lang,\n",
    "                              self.pairs[index])\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of sentence pairs.\"\"\"\n",
    "        return len(self.pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading lines...\n",
      "Read 135842 sentence pairs\n",
      "Trimmed to 10853 sentence pairs\n",
      "Counting words...\n",
      "Counted words:\n",
      "fra 4489\n",
      "eng 2925\n",
      "['je ne suis pas vegetarienne .', 'i m not a vegetarian .']\n"
     ]
    }
   ],
   "source": [
    "# Build the dataset with the defaults: prepareData on the 'eng'/'fra' corpus, read with reverse=True\n",
    "da = TextDataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No batch_size is given, so the DataLoader default is used; sentence pairs are drawn in shuffled order\n",
    "lang_dataloader = DataLoader(da, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  24,   10, 2092,  833,  454, 3613,    5,    1]])\n",
      "tensor([[  14,   40, 2253,  539,  638,  295, 1529,    4,    1]])\n",
      "torch.Size([1, 8])\n"
     ]
    }
   ],
   "source": [
    "# Peek at one batch: data[0] comes from the input tensor and data[1] from the target tensor of __getitem__\n",
    "for data in lang_dataloader:\n",
    "    print(data[0])\n",
    "    print(data[1])\n",
    "    print(data[0].shape)\n",
    "    break  # only the first batch is inspected"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
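    "总结：\n",
    "\n",
    "生成的索引序列长度不一致，因此这里只能使用`DataLoader`默认的`batch_size=1`逐句读取；若要按批次读取多个句子，需要先进行padding操作（例如在`collate_fn`中使用`torch.nn.utils.rnn.pad_sequence`）来统一句子长度。"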
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
